import argparse
import fnmatch
import re
import time
from multiprocessing import Pool

import numpy as np

from datasets.all_datasets import get_all_dataset_names, select_datasets

from experiments.experimental_pipeline import DEFAULT_RESULTS, get_default_regularization
from src.influence_functions import compute_model_influences
from datasets.load_datasets import load_dataset
from src.logistic_regression import RegularizationType, LogisticRegression
from src.shapley_values import (
    estimate_normalized_shapley_with_error,
    influence_shapley_estimate, single_monte_carlo_shapley_estimate
)


def run_single(args):
    """Run one Monte Carlo Shapley estimate; intended as a ``Pool`` worker.

    Args:
        args: A ``(dataset, permutation, regularization)`` tuple. It is packed
            into a single argument so the function can be mapped over an
            iterable of tasks.

    Returns:
        A ``(phi, permutation, weights)`` tuple, where ``phi`` and ``weights``
        are the two values returned by ``single_monte_carlo_shapley_estimate``
        and ``permutation`` is echoed back unchanged so the caller can match a
        result to the task that produced it.
    """
    dataset, permutation, regularization = args
    estimate = single_monte_carlo_shapley_estimate(
        dataset,
        permutation,
        regularization=regularization,
        reg_type=RegularizationType.L2,
        use_tqdm=False,
    )
    phi, weights = estimate
    return phi, permutation, weights


def _load_cached_results(output_file):
    """Read every array stored in the ``.npz`` archive at *output_file* into a dict."""
    # NOTE(review): allow_pickle=True unpickles object arrays; only load result
    # files that were written by this script / come from a trusted location.
    with np.load(output_file, allow_pickle=True) as archive:
        # Materialise all arrays before the archive (and its file handle) closes.
        return {key: archive[key] for key in archive.files}


def _save_results(output_file, cache, **updates):
    """Write *cache* merged with *updates* to *output_file*; *updates* win on key clashes."""
    # Merging into one dict first avoids "got multiple values for keyword
    # argument" when the loaded cache already contains a key being updated.
    np.savez_compressed(output_file, **{**cache, **updates})


def run_experiment(dataset_name, regularization=0.1, k=10, seed=42, verbosity=2,
                   load_cache=False, max_train=None, max_test=None, num_processes=2):
    """Compute IF-based and Monte Carlo Shapley estimates for one dataset.

    The results are written to
    ``DEFAULT_RESULTS / "<dataset_name><suffix>_shapley_values.npz"`` and the
    file is rewritten after every completed Monte Carlo run, so an interrupted
    experiment can be continued with ``load_cache=True``.

    Args:
        dataset_name: Name passed to ``load_dataset``; also used in the output
            file name.
        regularization: Regularization strength for the logistic regression
            model and for each Monte Carlo run (L2 in both cases).
        k: Target total number of Monte Carlo runs, including runs already
            present in a loaded cache.
        seed: Seed used for subsampling the dataset and for generating the
            Monte Carlo permutations.
        verbosity: ``>= 1`` makes model fitting verbose, ``>= 2`` also makes
            the influence computation verbose.
        load_cache: If true and the output file exists, reuse its contents
            (``mc_runs``, ``if_phi``, ``rif_phi``) and only run what is missing.
        max_train: Optional cap forwarded to ``dataset.subsample``; added to
            the output file name as ``train<max_train>``.
        max_test: Optional cap forwarded to ``dataset.subsample``; added to
            the output file name as ``test<max_test>``.
        num_processes: Number of worker processes for the Monte Carlo runs.

    Returns:
        None. All results are persisted to the output file.
    """
    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)
    print(f"{dataset.test.features.shape=}, {dataset.train.features.shape=}")
    dataset = dataset.subsample(max_train=max_train, max_test=max_test, seed=seed)
    n = dataset.train.features.shape[0]

    # Build the file-name suffix from whichever subsampling limits are active.
    suffix_parts = []
    if max_train is not None:
        suffix_parts.append(f"train{max_train}")
    if max_test is not None:
        suffix_parts.append(f"test{max_test}")
    suffix = ("_" + "_".join(suffix_parts)) if suffix_parts else ""

    output_file = DEFAULT_RESULTS / f"{dataset_name}{suffix}_shapley_values.npz"
    cache = {}
    if load_cache and output_file.exists():
        print(f"Loading cached results from {output_file}")
        cache = _load_cached_results(output_file)

    # One row per completed Monte Carlo run, one column per training sample.
    existing_mc_runs = cache.get("mc_runs", np.empty((0, n)))
    if_phi = cache.get("if_phi", None)
    rif_phi = cache.get("rif_phi", None)

    print("Training logistic regression model...")
    regression = LogisticRegression(
        dataset.train.features, dataset.train.labels,
        regularization=regularization,
        reg_type=RegularizationType.L2
    )
    regression.fit(verbose=(verbosity >= 1))

    if if_phi is None or rif_phi is None:
        print("Computing influence functions...")
        influences = compute_model_influences(
            regression=regression,
            experiment=dataset,
            verbose=verbosity >= 2
        )

        if if_phi is None:
            print("Computing IF-based Shapley estimate...")
            if_phi = influence_shapley_estimate(
                dataset=dataset,
                model=regression.model,
                influences=influences.influence_scores
            )

        if rif_phi is None:
            print("Computing rescaled IF-based Shapley estimate...")
            rif_phi = influence_shapley_estimate(
                dataset=dataset,
                model=regression.model,
                influences=influences.rescaled_influence_scores
            )

        print(f"Saving IF/RIF to {output_file}")
        _save_results(output_file, cache,
                      if_phi=if_phi, rif_phi=rif_phi, mc_runs=existing_mc_runs)

    # Monte Carlo loop with incremental saving
    print("Running Monte Carlo Shapley value estimation in parallel...")
    total_runs = existing_mc_runs.shape[0]
    remaining_runs = k - total_runs

    if remaining_runs <= 0:
        print("No more MC runs needed.")
        return

    # Keep seeding the global RNG as before for any code that relies on it.
    np.random.seed(seed)
    # Permutations come from a dedicated generator so the sequence does not
    # depend on other consumers of the global RNG. RandomState(seed) yields
    # the same stream as np.random after np.random.seed(seed).
    perm_rng = np.random.RandomState(seed)
    # When resuming, skip the permutations that the cached runs already used;
    # otherwise the "new" runs would just repeat the old ones.
    # Assumes the cached runs were produced with the same seed and sample
    # count -- TODO confirm if caches from other configurations are reused.
    for _ in range(total_runs):
        perm_rng.permutation(n)

    def arg_generator():
        # Lazily produce one task per outstanding run.
        for _ in range(remaining_runs):
            yield dataset, perm_rng.permutation(n), regularization

    with Pool(processes=num_processes) as pool:
        results = pool.imap_unordered(run_single, arg_generator())
        # The permutation and weights are returned by the worker but are
        # currently not persisted.
        for idx, (phi, _permutation, _weights) in enumerate(results, start=total_runs + 1):
            existing_mc_runs = np.vstack([existing_mc_runs, phi])

            mc_phi, mc_err = estimate_normalized_shapley_with_error(existing_mc_runs)

            print(f"Saving results after completed run {idx}")
            _save_results(
                output_file,
                cache,
                mc_runs=existing_mc_runs,
                mc_phi=mc_phi,
                mc_err=mc_err,
                if_phi=if_phi,
                rif_phi=rif_phi
            )


def main():
    """Parse command-line options and run the experiment for each selected dataset.

    With ``--list`` the available dataset names are printed and nothing is run.
    Otherwise ``--dataset`` is required; it is matched against the known
    datasets by ``select_datasets`` (as a regex when ``--regex`` is given).
    The process exits with status 1 when no dataset matches.
    """
    parser = argparse.ArgumentParser(description="Run experiments for selected datasets.")
    parser.add_argument("--dataset", type=str, help="Dataset name or regex pattern")
    parser.add_argument("--regex", action="store_true", help="Interpret dataset pattern as regex")
    parser.add_argument("--list", action="store_true", help="List all dataset names and exit")
    parser.add_argument("--k", type=int, default=10, help="Number of Monte Carlo runs")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--verbosity", type=int, default=2, help="Verbosity level")
    parser.add_argument("--load_cache", action="store_true", help="Load and extend existing results if available")
    parser.add_argument("--max_train", type=int, default=None, help="Maximum number of training samples to use")
    parser.add_argument("--max_test", type=int, default=None, help="Maximum number of test samples to use")
    parser.add_argument("--num_processes", type=int, default=2,
                        help="Number of parallel processes for Monte Carlo runs")

    args = parser.parse_args()

    if args.list:
        print("Available datasets:")
        for name in get_all_dataset_names():
            print(" -", name)
        return

    if not args.dataset:
        parser.error("the following arguments are required: --dataset (unless using --list)")

    if args.regex:
        # Fail early with a usage error instead of a traceback on a bad pattern.
        try:
            re.compile(args.dataset, re.IGNORECASE)
        except re.error as exc:
            parser.error(f"invalid regular expression for --dataset: {exc}")

    selected_datasets = select_datasets(args.dataset, args.regex)

    if not selected_datasets:
        print(f"No datasets matched pattern: {args.dataset}")
        # SystemExit instead of the interactive-only ``exit`` helper from ``site``.
        raise SystemExit(1)
    print(f"Selected datasets: {selected_datasets}")

    print(f"{args.max_train=}, {args.max_test=}")

    for dataset_name in selected_datasets:
        print("\n" + "=" * 80)
        print(f"Starting experiment for dataset: {dataset_name}")
        print("=" * 80)
        t0 = time.time()
        # Only the first element of the returned pair is used here.
        regularization, _ = get_default_regularization(dataset_name)
        run_experiment(
            dataset_name,
            regularization=regularization,
            k=args.k,
            seed=args.seed,
            verbosity=args.verbosity,
            load_cache=args.load_cache,
            max_train=args.max_train,
            max_test=args.max_test,
            num_processes=args.num_processes
        )
        print(f"Runtime {time.time() - t0:.1f}")


if __name__ == "__main__":
    main()
